{# ----------------------------------------------------------------------------
 # SymForce - Copyright 2022, Skydio, Inc.
 # This source code is under the Apache 2.0 license found in the LICENSE file.
 # ---------------------------------------------------------------------------- #}
{%- import "../util/util.jinja" as util with context -%}

# ruff: noqa: F401, PLR0912, PLR0913, PLR0914, PLR0915, PLR0917, RUF100

import math
import typing as T

import torch


class TensorKwargs(T.TypedDict):
    """
    TypedDict representing args that will be passed to any torch.tensor calls

    Both keys are required (the TypedDict is declared with the default totality):
        device: a torch.device
        dtype: a torch.dtype
    """
    # Value for the `device` keyword
    device: torch.device
    # Value for the `dtype` keyword
    dtype: torch.dtype


def _broadcast_and_stack(tensors, dim=-1):
    # type: (T.List[torch.Tensor], int) -> torch.Tensor
    """
    Expand every input to the shape they all broadcast to, then join them on a new axis

    Args:
        tensors: tensors whose shapes are mutually broadcast-compatible
        dim: position at which the new stacking axis is inserted (defaults to last)

    Returns:
        One tensor holding the expanded inputs stacked along `dim`
    """
    # Shape that all of the inputs can be broadcast to
    common_shape = torch.broadcast_shapes(*[t.shape for t in tensors])

    expanded = []
    for t in tensors:
        expanded.append(t.broadcast_to(common_shape))

    return torch.stack(expanded, dim=dim)

{# Template for the generated function described by `spec`.
 # Invokes util.function_declaration(spec); then, only when spec.docstring is
 # set, util.print_docstring(spec.docstring) piped through the indent(4) filter;
 # and finally util.expr_code(spec). The macros are imported from
 # ../util/util.jinja (presumably signature / docstring / body renderers --
 # confirm there). The trailing "-" below strips the newline that follows this
 # comment, so the rendered output is unchanged. -#}
{{ util.function_declaration(spec) }}
    {% if spec.docstring %}
    {{ util.print_docstring(spec.docstring) | indent(4) }}
    {% endif %}

    {{ util.expr_code(spec) }}
